#include "Vocabulary.h"
#include "DescManip.h"

#include <vector>
#include <numeric>
#include <fstream>
#include <string>
#include <algorithm>
#include <opencv2/core.hpp>

#include "FeatureVector.h"
#include "BowVector.h"
#include "ScoringObject.h"

using namespace std;
namespace DBoW3
{
// --------------------------------------------------------------------------


/**
 * Creates a vocabulary configuration without building any tree.
 * @param k         branching factor of the tree
 * @param L         depth levels of the tree
 * @param weighting word weighting scheme
 * @param scoring   scoring scheme; the matching scoring object is created here
 */
Vocabulary::Vocabulary ( int k, int L, WeightingType weighting,
                         ScoringType scoring )
    : m_k ( k ),
      m_L ( L ),
      m_weighting ( weighting ),
      m_scoring ( scoring ),
      m_scoring_object ( NULL )
{
    this->createScoringObject();
}

// --------------------------------------------------------------------------


/**
 * Loads a vocabulary from a file readable by cv::FileStorage.
 * @param filename path of the vocabulary file
 * @throws std::string if the file cannot be opened (thrown by load())
 */
Vocabulary::Vocabulary ( const std::string &filename )
    : m_scoring_object ( NULL )
{
    this->load ( filename );
}

// --------------------------------------------------------------------------


/**
 * C-string convenience overload of the file-loading constructor.
 * @param filename path of the vocabulary file
 * @throws std::string if the file cannot be opened (thrown by load())
 */
Vocabulary::Vocabulary ( const char *filename )
    : m_scoring_object ( NULL )
{
    this->load ( std::string ( filename ) );
}

// --------------------------------------------------------------------------


/**
 * (Re)creates m_scoring_object so that it matches m_scoring.
 * Any previous scoring object is deleted first. For a scoring type not listed
 * below the pointer is left NULL.
 */
void Vocabulary::createScoringObject()
{
    delete m_scoring_object;
    m_scoring_object = NULL;

    if ( m_scoring == L1_NORM )
        m_scoring_object = new L1Scoring;
    else if ( m_scoring == L2_NORM )
        m_scoring_object = new L2Scoring;
    else if ( m_scoring == CHI_SQUARE )
        m_scoring_object = new ChiSquareScoring;
    else if ( m_scoring == KL )
        m_scoring_object = new KLScoring;
    else if ( m_scoring == BHATTACHARYYA )
        m_scoring_object = new BhattacharyyaScoring;
    else if ( m_scoring == DOT_PRODUCT )
        m_scoring_object = new DotProductScoring;
}

// --------------------------------------------------------------------------


/**
 * Changes the scoring type and rebuilds the corresponding scoring object.
 * @param type new scoring type
 */
void Vocabulary::setScoringType ( ScoringType type )
{
    this->m_scoring = type;
    this->createScoringObject();
}

// --------------------------------------------------------------------------


/**
 * Changes the weighting type. Node weights already computed are not updated.
 * @param type new weighting type
 */
void Vocabulary::setWeightingType ( WeightingType type )
{
    m_weighting = type;
}

// --------------------------------------------------------------------------


/**
 * Copy constructor: starts without a scoring object and lets operator=
 * copy the configuration and the tree.
 * @param voc vocabulary to copy
 */
Vocabulary::Vocabulary ( const Vocabulary &voc )
    : m_scoring_object ( NULL )
{
    this->operator= ( voc );
}

// --------------------------------------------------------------------------


/// Releases the scoring object owned by this vocabulary.
Vocabulary::~Vocabulary()
{
    if ( m_scoring_object != NULL )
        delete m_scoring_object;
}

// --------------------------------------------------------------------------


/**
 * Copies the configuration and the tree of another vocabulary.
 *
 * The scoring object is rebuilt for the copied scoring type, and m_words is
 * regenerated from the copied nodes (it holds pointers into m_nodes, so it
 * cannot be copied directly).
 *
 * @param voc vocabulary to copy
 * @return *this
 */
Vocabulary&
Vocabulary::operator=
( const Vocabulary &voc )
{
    // Self-assignment guard: without it, m_nodes.clear() below would also
    // clear voc.m_nodes (same object), and the vocabulary would be wiped.
    if ( this == &voc ) return *this;

    this->m_k = voc.m_k;
    this->m_L = voc.m_L;
    this->m_scoring = voc.m_scoring;
    this->m_weighting = voc.m_weighting;

    this->createScoringObject();

    this->m_nodes.clear();
    this->m_words.clear();

    this->m_nodes = voc.m_nodes;
    this->createWords();

    return *this;
}



/**
 * Builds the vocabulary from one descriptor matrix per training image.
 * Each row of training_features[i] is taken as one descriptor of image i.
 * @param training_features descriptor matrices, one per image
 */
void Vocabulary::create (
    const std::vector< cv::Mat > &training_features )
{
    std::vector<std::vector<cv::Mat> > per_image ( training_features.size() );
    for ( size_t img = 0; img < training_features.size(); ++img )
    {
        const cv::Mat &descriptors = training_features[img];
        std::vector<cv::Mat> &rows = per_image[img];
        rows.reserve ( descriptors.rows );
        for ( int r = 0; r < descriptors.rows; ++r )
            rows.push_back ( descriptors.row ( r ) ); // header on row r, no data copy
    }
    create ( per_image );
}

/**
 * Builds the vocabulary tree, the words and their weights from training
 * descriptors, using the current m_k, m_L, weighting and scoring settings.
 *
 * @param training_features descriptors grouped per training image; each inner
 *        vector is treated as one document when idf weights are computed
 *        (see setNodeWeights)
 */
void Vocabulary::create (
    const std::vector<std::vector<cv::Mat> > &training_features )
{
    m_nodes.clear();
    m_words.clear();

    // Upper bound on the number of nodes: Sum_{i=0..L} ( k^i ).
    // It is summed level by level instead of using the closed form
    // (k^(L+1) - 1) / (k - 1), which divides by zero when k == 1.
    // Reserving avoids reallocations of m_nodes while the tree is built.
    if ( m_k > 0 && m_L >= 0 )
    {
        double expected_nodes = 0.0;
        double level_nodes = 1.0;
        for ( int level = 0; level <= m_L; ++level )
        {
            expected_nodes += level_nodes;
            level_nodes *= m_k;
        }
        if ( expected_nodes <= ( double ) m_nodes.max_size() )
            m_nodes.reserve ( ( size_t ) expected_nodes );
    }

    std::vector<cv::Mat> features;
    getFeatures ( training_features, features );

    // create root
    m_nodes.push_back ( Node ( 0 ) ); // root

    // create the tree
    HKmeansStep ( 0, features, 1 );

    // create the words
    createWords();

    // and set the weight of each node of the tree
    setNodeWeights ( training_features );
}

// --------------------------------------------------------------------------


void Vocabulary::create (
    const std::vector<std::vector<cv::Mat> > &training_features,
    int k, int L )
{
    m_k = k;
    m_L = L;

    create ( training_features );
}

// --------------------------------------------------------------------------


void Vocabulary::create (
    const std::vector<std::vector<cv::Mat> > &training_features,
    int k, int L, WeightingType weighting, ScoringType scoring )
{
    m_k = k;
    m_L = L;
    m_weighting = weighting;
    m_scoring = scoring;
    createScoringObject();

    create ( training_features );
}

// --------------------------------------------------------------------------


/**
 * Flattens per-image descriptors into a single list, preserving order.
 * @param training_features descriptors grouped per training image
 * @param features          output: all descriptors, image after image
 */
void Vocabulary::getFeatures (
    const std::vector<std::vector<cv::Mat> > &training_features,
    std::vector<cv::Mat> &features ) const
{
    features.clear();
    for ( const auto &image_features : training_features )
        features.insert ( features.end(), image_features.begin(), image_features.end() );
}

// --------------------------------------------------------------------------


/**
 * Runs one level of hierarchical k-means and recurses into the resulting clusters.
 *
 * The descriptors belonging to @p parent_id are split into at most m_k
 * clusters (one cluster per descriptor when there are no more than m_k).
 * One child node of @p parent_id is appended to m_nodes per cluster, holding
 * the cluster centre as its descriptor. While current_level < m_L, every child
 * that received more than one descriptor is split again.
 *
 * @param parent_id     node whose children are created by this call
 * @param descriptors   descriptors belonging to @p parent_id
 * @param current_level depth of the children created here (root's children are level 1)
 */
void Vocabulary::HKmeansStep ( NodeId parent_id,
                               const std::vector<cv::Mat> &descriptors, int current_level )
{
    if ( descriptors.empty() ) return;

    // clusters[i] is the centre of cluster i; groups[i] holds the indices
    // (into descriptors) of the descriptors associated to cluster i
    std::vector<cv::Mat> clusters;
    std::vector<std::vector<unsigned int> > groups;

    clusters.reserve ( m_k );
    groups.reserve ( m_k );

    if ( ( int ) descriptors.size() <= m_k )
    {
        // trivial case: one cluster per feature
        groups.resize ( descriptors.size() );

        for ( unsigned int i = 0; i < descriptors.size(); i++ )
        {
            groups[i].push_back ( i );
            clusters.push_back ( descriptors[i] );
        }
    }
    else
    {
        // k-means: kmeans++ seeding, then alternate centre update and
        // assignment until no descriptor changes cluster.
        // NOTE(review): there is no iteration cap; termination relies on the
        // association eventually becoming stable.

        bool first_time = true;
        bool goon = true;

        // to check if clusters move after iterations
        std::vector<int> last_association, current_association;

        while ( goon )
        {
            // 1. Calculate clusters
            if ( first_time )
            {
                // random sample
                initiateClusters ( descriptors, clusters );
            }
            else
            {
                // recompute every centre as the mean of its members
                for ( unsigned int c = 0; c < clusters.size(); ++c )
                {
                    std::vector<cv::Mat> cluster_descriptors;
                    cluster_descriptors.reserve ( groups[c].size() );

                    for ( auto vit = groups[c].begin(); vit != groups[c].end(); ++vit )
                        cluster_descriptors.push_back ( descriptors[*vit] );

                    DescManip::meanValue ( cluster_descriptors, clusters[c] );
                }
            }

            // 2. Associate every descriptor with its nearest cluster centre
            groups.clear();
            groups.resize ( clusters.size(), std::vector<unsigned int>() );
            current_association.resize ( descriptors.size() );

            for ( auto fit = descriptors.begin(); fit != descriptors.end(); ++fit )
            {
                double best_dist = DescManip::distance ( ( *fit ), clusters[0] );
                unsigned int icluster = 0;

                for ( unsigned int c = 1; c < clusters.size(); ++c )
                {
                    double dist = DescManip::distance ( ( *fit ), clusters[c] );
                    if ( dist < best_dist )
                    {
                        best_dist = dist;
                        icluster = c;
                    }
                }

                const unsigned int d = fit - descriptors.begin();
                groups[icluster].push_back ( d );
                current_association[d] = icluster;
            }

            // 3. check convergence: stop once the association is unchanged
            if ( first_time )
            {
                first_time = false;
            }
            else
            {
                goon = ( current_association != last_association );
            }

            if ( goon )
            {
                // copy last feature-cluster association
                last_association = current_association;
            }

        } // while(goon)

    } // if must run kmeans

    // create one child node per cluster
    for ( unsigned int i = 0; i < clusters.size(); ++i )
    {
        NodeId id = m_nodes.size();
        m_nodes.push_back ( Node ( id ) );
        m_nodes.back().descriptor = clusters[i];
        m_nodes.back().parent = parent_id;
        m_nodes[parent_id].children.push_back ( id );
    }

    // go on with the next level
    if ( current_level < m_L )
    {
        // Copy (rather than reference) the child ids: the recursive calls below
        // push_back into m_nodes, and a reallocation would leave a reference to
        // m_nodes[parent_id].children dangling.
        const std::vector<NodeId> children_ids = m_nodes[parent_id].children;
        for ( unsigned int i = 0; i < clusters.size(); ++i )
        {
            NodeId id = children_ids[i];

            std::vector<cv::Mat> child_features;
            child_features.reserve ( groups[i].size() );

            for ( auto vit = groups[i].begin(); vit != groups[i].end(); ++vit )
                child_features.push_back ( descriptors[*vit] );

            if ( child_features.size() > 1 )
            {
                HKmeansStep ( id, child_features, current_level + 1 );
            }
        }
    }
}

// --------------------------------------------------------------------------


/**
 * Chooses the initial cluster centres for k-means.
 * Currently always uses kmeans++-style seeding (initiateClustersKMpp).
 * @param descriptors descriptors to cluster
 * @param clusters    output: initial centres
 */
void Vocabulary::initiateClusters ( const std::vector<cv::Mat> &descriptors,
                                    std::vector<cv::Mat> &clusters ) const
{
    this->initiateClustersKMpp ( descriptors, clusters );
}

// --------------------------------------------------------------------------


/**
 * Chooses initial cluster centres with kmeans++-style seeding.
 *
 * @param pfeatures descriptors to seed from. Callers are expected to pass a
 *                  non-empty set, because rand() % pfeatures.size() divides by
 *                  zero otherwise.
 * @param clusters  output: one or more (up to m_k) descriptors copied from
 *                  pfeatures. Fewer than m_k are returned when every remaining
 *                  point is at distance 0 from an existing centre.
 */
void Vocabulary::initiateClustersKMpp (
    const std::vector<cv::Mat> &pfeatures,
    std::vector<cv::Mat> &clusters ) const
{
    // Implements kmeans++ seeding algorithm
    // Algorithm:
    // 1. Choose one center uniformly at random from among the data points.
    // 2. For each data point x, compute D(x), the distance between x and the nearest
    //    center that has already been chosen.
    // 3. Add one new data point as a center. Each point x is chosen with probability
    //    proportional to D(x)^2.
    // 4. Repeat Steps 2 and 3 until k centers have been chosen.
    // 5. Now that the initial centers have been chosen, proceed using standard k-means
    //    clustering.
    //
    // NOTE(review): min_dists stores D(x) itself (it is never squared), so step 3
    // below samples proportionally to D(x), not D(x)^2 as in textbook kmeans++.
    // Randomness comes from the global rand(), which is not seeded here.


//  DUtils::Random::SeedRandOnce();

    clusters.resize ( 0 );
    clusters.reserve ( m_k );
    // min_dists[i]: distance from pfeatures[i] to its nearest chosen centre
    std::vector<double> min_dists ( pfeatures.size(), std::numeric_limits<double>::max() );

    // 1.

    int ifeature = rand() % pfeatures.size(); //DUtils::Random::RandomInt(0, pfeatures.size()-1);

    // create first cluster
    clusters.push_back ( pfeatures[ifeature] );

    // compute the initial distances
    std::vector<double>::iterator dit;
    dit = min_dists.begin();
    for ( auto fit = pfeatures.begin(); fit != pfeatures.end(); ++fit, ++dit )
    {
        *dit = DescManip::distance ( ( *fit ), clusters.back() );
    }

    while ( ( int ) clusters.size() < m_k )
    {
        // 2. update the nearest-centre distances with the newest centre
        //    (on the first pass this repeats the distances to the first centre
        //    computed just above; redundant but harmless)
        dit = min_dists.begin();
        for ( auto  fit = pfeatures.begin(); fit != pfeatures.end(); ++fit, ++dit )
        {
            // a zero distance cannot decrease further (assuming non-negative distances)
            if ( *dit > 0 )
            {
                double dist = DescManip::distance ( ( *fit ), clusters.back() );
                if ( dist < *dit ) *dit = dist;
            }
        }

        // 3. weighted random choice of the next centre
        double dist_sum = std::accumulate ( min_dists.begin(), min_dists.end(), 0.0 );

        if ( dist_sum > 0 )
        {
            // random threshold in (0, dist_sum]
            double cut_d;
            do
            {

                cut_d = ( double ( rand() ) / double ( RAND_MAX ) ) * dist_sum;
            }
            while ( cut_d == 0.0 );

            // first point whose cumulative distance reaches the threshold
            double d_up_now = 0;
            for ( dit = min_dists.begin(); dit != min_dists.end(); ++dit )
            {
                d_up_now += *dit;
                if ( d_up_now >= cut_d ) break;
            }

            // floating-point rounding may leave the cumulative sum just below
            // cut_d; fall back to the last point in that case
            if ( dit == min_dists.end() )
                ifeature = pfeatures.size()-1;
            else
                ifeature = dit - min_dists.begin();


            clusters.push_back ( pfeatures[ifeature] );
        } // if dist_sum > 0
        else
            break; // every point coincides with a chosen centre: no new centre possible

    } // while(clusters.size() < m_k)

}

// --------------------------------------------------------------------------


/**
 * Rebuilds m_words from the tree. Every leaf node, visited in node-id order
 * with the root excluded, becomes a word, and its word_id is set to its index
 * in m_words.
 */
void Vocabulary::createWords()
{
    m_words.clear();
    if ( m_nodes.empty() ) return;

    m_words.reserve ( ( int ) pow ( ( double ) m_k, ( double ) m_L ) );

    // start at index 1: node 0 is the root
    for ( size_t i = 1; i < m_nodes.size(); ++i )
    {
        auto &node = m_nodes[i];
        if ( !node.isLeaf() ) continue;

        node.word_id = m_words.size();
        m_words.push_back ( &node );
    }
}

// --------------------------------------------------------------------------


void Vocabulary::setNodeWeights
( const std::vector<std::vector<cv::Mat> > &training_features )
{
    const unsigned int NWords = m_words.size();
    const unsigned int NDocs = training_features.size();

    if ( m_weighting == TF || m_weighting == BINARY )
    {
        // idf part must be 1 always
        for ( unsigned int i = 0; i < NWords; i++ )
            m_words[i]->weight = 1;
    }
    else if ( m_weighting == IDF || m_weighting == TF_IDF )
    {
        // IDF and TF-IDF: we calculte the idf path now

        // Note: this actually calculates the idf part of the tf-idf score.
        // The complete tf-idf score is calculated in ::transform

        std::vector<unsigned int> Ni ( NWords, 0 );
        std::vector<bool> counted ( NWords, false );


        for ( auto mit = training_features.begin(); mit != training_features.end(); ++mit )
        {
            fill ( counted.begin(), counted.end(), false );

            for ( auto fit = mit->begin(); fit < mit->end(); ++fit )
            {
                WordId word_id;
                transform ( *fit, word_id );

                if ( !counted[word_id] )
                {
                    Ni[word_id]++;
                    counted[word_id] = true;
                }
            }
        }

        // set ln(N/Ni)
        for ( unsigned int i = 0; i < NWords; i++ )
        {
            if ( Ni[i] > 0 )
            {
                m_words[i]->weight = log ( ( double ) NDocs / ( double ) Ni[i] );
            }// else // This cannot occur if using kmeans++
        }

    }

}

// --------------------------------------------------------------------------






// --------------------------------------------------------------------------


/**
 * Average depth of the words: the mean number of parent hops from each word
 * (leaf) up to the root.
 * @return the average depth, or 0 for a vocabulary without words
 */
float Vocabulary::getEffectiveLevels() const
{
    // without words the average is 0/0, which would yield NaN
    if ( m_words.empty() ) return 0.f;

    long sum = 0;
    for ( auto wit = m_words.begin(); wit != m_words.end(); ++wit )
    {
        const Node *p = *wit;

        // walk up to the root (id 0), counting the hops
        for ( ; p->id != 0; sum++ ) p = &m_nodes[p->parent];
    }

    return ( float ) ( ( double ) sum / ( double ) m_words.size() );
}

// --------------------------------------------------------------------------


/**
 * Returns the descriptor of a word. The id is not range-checked.
 * @param wid word id
 */
cv::Mat Vocabulary::getWord ( WordId wid ) const
{
    const auto *word = m_words[wid];
    return word->descriptor;
}

// --------------------------------------------------------------------------


/**
 * Returns the weight of a word. The id is not range-checked.
 * @param wid word id
 */
WordValue Vocabulary::getWordWeight ( WordId wid ) const
{
    const auto *word = m_words[wid];
    return word->weight;
}

// --------------------------------------------------------------------------


/**
 * Returns the id of the word a single descriptor falls into.
 * @param feature descriptor
 * @return word id, or 0 if the vocabulary is empty
 */
WordId Vocabulary::transform ( const cv::Mat& feature ) const
{
    WordId wid = 0;
    if ( !empty() )
        transform ( feature, wid );
    return wid;
}

// --------------------------------------------------------------------------

/**
 * Converts a descriptor matrix into a bag-of-words vector. Each row of
 * @p features is treated as one descriptor.
 * @param features descriptor matrix (one descriptor per row)
 * @param v        output bag-of-words vector
 */
void Vocabulary::transform (
    const cv::Mat& features, BowVector &v ) const
{
    std::vector<cv::Mat> rows;
    rows.reserve ( features.rows );
    for ( int r = 0; r < features.rows; ++r )
        rows.push_back ( features.row ( r ) );

    this->transform ( rows, v );
}

/**
 * Converts a set of descriptors into a bag-of-words vector.
 *
 * With TF / TF_IDF weighting the word weights of repeated words are
 * accumulated. When the scoring does not normalize, each entry is then
 * divided by the number of distinct words. With IDF / BINARY weighting only
 * the first occurrence of each word is added. Words with weight <= 0
 * (stopped) are skipped. If the scoring object requires it, the vector is
 * normalized at the end.
 *
 * @param features descriptors to convert
 * @param v        output bag-of-words vector (cleared first)
 */
void Vocabulary::transform (
    const std::vector<cv::Mat>& features, BowVector &v ) const
{
    v.clear();
    if ( empty() ) return;

    LNorm norm;
    const bool must_normalize = m_scoring_object->mustNormalize ( norm );
    const bool accumulate_tf = ( m_weighting == TF || m_weighting == TF_IDF );

    for ( const cv::Mat &feature : features )
    {
        WordId word;
        WordValue weight; // idf if TF_IDF/IDF, 1 if TF/BINARY
        transform ( feature, word, weight );

        if ( weight > 0 ) // not stopped
        {
            if ( accumulate_tf )
                v.addWeight ( word, weight );
            else
                v.addIfNotExist ( word, weight );
        }
    }

    if ( accumulate_tf && !v.empty() && !must_normalize )
    {
        // term frequency: divide by the number of distinct words
        const double nd = v.size();
        for ( auto &entry : v )
            entry.second /= nd;
    }

    if ( must_normalize ) v.normalize ( norm );
}

// --------------------------------------------------------------------------


/**
 * Converts a set of descriptors into a bag-of-words vector and a feature
 * vector.
 *
 * Weighting, stop-word skipping and normalization follow the BowVector-only
 * overload. In addition, for every non-stopped descriptor its index is
 * recorded in @p fv under the node reached at @p levelsup levels above the
 * leaves.
 *
 * @param features descriptors to convert
 * @param v        output bag-of-words vector (cleared first)
 * @param fv       output feature vector (cleared first)
 * @param levelsup levels above the leaves at which descriptors are grouped in fv
 */
void Vocabulary::transform (
    const std::vector<cv::Mat>& features,
    BowVector &v, FeatureVector &fv, int levelsup ) const
{
    v.clear();
    fv.clear();

    if ( empty() ) return; // safe for subclasses

    LNorm norm;
    const bool must_normalize = m_scoring_object->mustNormalize ( norm );
    const bool accumulate_tf = ( m_weighting == TF || m_weighting == TF_IDF );

    for ( unsigned int i_feature = 0; i_feature < features.size(); ++i_feature )
    {
        WordId word;
        NodeId node;
        WordValue weight; // idf if TF_IDF/IDF, 1 if TF/BINARY
        transform ( features[i_feature], word, weight, &node, levelsup );

        if ( weight > 0 ) // not stopped
        {
            if ( accumulate_tf )
                v.addWeight ( word, weight );
            else
                v.addIfNotExist ( word, weight );
            fv.addFeature ( node, i_feature );
        }
    }

    if ( accumulate_tf && !v.empty() && !must_normalize )
    {
        // term frequency: divide by the number of distinct words
        const double nd = v.size();
        for ( auto &entry : v )
            entry.second /= nd;
    }

    if ( must_normalize ) v.normalize ( norm );
}

// --------------------------------------------------------------------------


// --------------------------------------------------------------------------


/**
 * Finds the word for a single descriptor, discarding its weight.
 * @param feature descriptor
 * @param id      output word id
 */
void Vocabulary::transform ( const cv::Mat &feature, WordId &id ) const
{
    WordValue ignored_weight;
    this->transform ( feature, id, ignored_weight );
}

// --------------------------------------------------------------------------


/**
 * Propagates a single descriptor down the tree to a leaf (word).
 *
 * At each level the child whose descriptor is closest (DescManip::distance)
 * is followed, until a leaf is reached.
 *
 * @param feature  descriptor to classify
 * @param word_id  output: id of the word (leaf) reached
 * @param weight   output: weight stored in that leaf
 * @param nid      optional output: the node traversed at depth m_L - levelsup,
 *                 or the root (0) when that depth is <= 0. If the leaf is
 *                 reached at a shallower depth, *nid is not written by the
 *                 descent.
 * @param levelsup levels above the leaf level at which @p nid is taken
 *
 * If the tree has no nodes below the root, word_id and weight are set to 0
 * (and *nid to the root) instead of indexing an empty children list.
 */
void Vocabulary::transform ( const cv::Mat &feature,
                             WordId &word_id, WordValue &weight, NodeId *nid, int levelsup ) const
{
    // level at which the node must be stored in nid, if given
    const int nid_level = m_L - levelsup;
    if ( nid_level <= 0 && nid != NULL ) *nid = 0; // root

    if ( m_nodes.empty() || m_nodes[0].children.empty() )
    {
        word_id = 0;
        weight = 0;
        if ( nid != NULL ) *nid = 0;
        return;
    }

    NodeId final_id = 0; // root
    int current_level = 0;

    do
    {
        ++current_level;

        // bound by reference: no need to copy the children list at every level
        const std::vector<NodeId> &children = m_nodes[final_id].children;

        NodeId best_id = children[0];
        double best_d = DescManip::distance ( feature, m_nodes[best_id].descriptor );

        for ( size_t c = 1; c < children.size(); ++c )
        {
            const NodeId id = children[c];
            const double d = DescManip::distance ( feature, m_nodes[id].descriptor );
            if ( d < best_d )
            {
                best_d = d;
                best_id = id;
            }
        }
        final_id = best_id;

        if ( nid != NULL && current_level == nid_level )
            *nid = final_id;

    }
    while ( !m_nodes[final_id].isLeaf() );

    // turn node id into word id
    word_id = m_nodes[final_id].word_id;
    weight = m_nodes[final_id].weight;
}

// --------------------------------------------------------------------------


/**
 * Returns the ancestor of a word's node @p levelsup levels up, stopping at
 * the root (id 0).
 * @param wid      word id
 * @param levelsup number of levels to climb
 */
NodeId Vocabulary::getParentNode ( WordId wid, int levelsup ) const
{
    NodeId node = m_words[wid]->id;
    for ( ; levelsup > 0 && node != 0; --levelsup ) // node == 0 --> root
        node = m_nodes[node].parent;
    return node;
}

// --------------------------------------------------------------------------


/**
 * Collects the ids of all words (leaves) in the subtree rooted at @p nid.
 * Traversal is depth-first with an explicit stack.
 * @param nid   subtree root
 * @param words output word ids (cleared first)
 */
void Vocabulary::getWordsFromNode ( NodeId nid, std::vector<WordId> &words ) const
{
    words.clear();

    if ( m_nodes[nid].isLeaf() )
    {
        words.push_back ( m_nodes[nid].word_id );
        return;
    }

    words.reserve ( m_k ); // ^1, ^2, ...

    std::vector<NodeId> pending;
    pending.push_back ( nid );

    while ( !pending.empty() )
    {
        const NodeId current = pending.back();
        pending.pop_back();

        for ( NodeId child_id : m_nodes[current].children )
        {
            const auto &child = m_nodes[child_id];
            if ( child.isLeaf() )
                words.push_back ( child.word_id );
            else
                pending.push_back ( child_id );
        }
    }
}

// --------------------------------------------------------------------------


int Vocabulary::stopWords ( double minWeight )
{
    int c = 0;
    for ( auto wit = m_words.begin(); wit != m_words.end(); ++wit )
    {
        if ( ( *wit )->weight < minWeight )
        {
            ++c;
            ( *wit )->weight = 0;
        }
    }
    return c;
}

// --------------------------------------------------------------------------


/**
 * Saves the vocabulary with cv::FileStorage (format chosen by OpenCV from the
 * file extension).
 * @param filename output path
 * @throws std::string if the file cannot be opened
 */
void Vocabulary::save ( const std::string &filename ) const
{
    cv::FileStorage fs ( filename.c_str(), cv::FileStorage::WRITE );
    if ( fs.isOpened() )
    {
        save ( fs );
        return;
    }
    throw std::string ( "Could not open file " ) + filename;
}
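/*
 * Disabled: an earlier cv::FileStorage-based variant of saveToTextFile, kept
 * for reference. (This is a plain block comment, not a Doxygen comment, so it
 * is not attached as documentation to the function that follows.)
//void Vocabulary::saveToTextFile ( const std::string &filename ) const
//{
//    cv::FileStorage fs ( filename.c_str(), cv::FileStorage::WRITE );
//    if ( !fs.isOpened() ) throw std::string ( "Could not open file " ) + filename;
//
//    saveToTextFile ( fs );
//}
*/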

/**
 * Saves the vocabulary as plain text.
 *
 * Format:
 *  - first line: branching factor k, depth L, scoring type and weighting type
 *    (enums written as integers; the double space after L is part of the
 *    existing format and is kept for compatibility)
 *  - then one line per node except the root, in node-id order:
 *    parent id, 1 if the node is a leaf (word) else 0, the descriptor as
 *    produced by DescManip::toString, and the node weight
 *
 * @param filename output path
 * @throws std::string if the file cannot be opened or an error occurs while writing
 */
void Vocabulary::saveToTextFile (const std::string &filename) const
{
    std::ofstream f ( filename.c_str() );
    if ( !f.is_open() ) throw std::string ( "Could not open file " ) + filename;

    f << m_k << " " << m_L << " " << " " << m_scoring << " " << m_weighting << '\n';

    for ( size_t i = 1; i < m_nodes.size(); i++ )
    {
        const Node& node = m_nodes[i];

        f << node.parent << " ";
        f << ( node.isLeaf() ? 1 : 0 ) << " ";
        // weight is only expected to be non-zero for leaves
        f << DescManip::toString ( node.descriptor ) << " " << ( double ) node.weight << '\n';
    }

    f.close();
    if ( !f ) throw std::string ( "Error while writing file " ) + filename;
}

// 参考： https://blog.csdn.net/weixin_39917437/article/details/111122918
// sizeof(m_nodes.at(0).parent)==4;
// F::L * sizeof(char)==32;
// sizeof(float)=4;
// sizeof(bool)==1;
// node_size== 41;
void Vocabulary::saveToBinaryFile(const std::string &filename) const
{
    fstream f;
    f.open(filename.c_str(), ios_base::out|ios::binary);
    unsigned int nb_nodes = m_nodes.size();
    float _weight;
    //unsigned int size_node = sizeof(m_nodes[0].parent) + F::L*sizeof(char) + sizeof(_weight) + sizeof(bool);
    unsigned int size_node = sizeof(m_nodes[0].parent) + 32 + sizeof(_weight) + sizeof(bool);
    f.write((char*)&nb_nodes, sizeof(nb_nodes));
    f.write((char*)&size_node, sizeof(size_node));
    f.write((char*)&m_k, sizeof(m_k));
    f.write((char*)&m_L, sizeof(m_L));
    f.write((char*)&m_scoring, sizeof(m_scoring));
    f.write((char*)&m_weighting, sizeof(m_weighting));
    for(size_t i=1; i<nb_nodes;i++) {
        const Node& node = m_nodes[i];
        f.write((char*)&node.parent, sizeof(node.parent));
//        f.write((char*)node.descriptor.data, F::L);
        f.write((char*)node.descriptor.data, 32);
        _weight = node.weight;
        f.write((char*)&_weight, sizeof(_weight));
        bool is_leaf = node.isLeaf();
        f.write((char*)&is_leaf, sizeof(is_leaf)); // i put this one at the end for alignement....
    }

    f.close();
}
// --------------------------------------------------------------------------


/**
 * Loads a vocabulary from a file readable by cv::FileStorage.
 * @param filename input path
 * @throws std::string if the file cannot be opened
 */
void Vocabulary::load ( const std::string &filename )
{
    cv::FileStorage fs ( filename.c_str(), cv::FileStorage::READ );
    if ( fs.isOpened() )
    {
        this->load ( fs );
        return;
    }
    throw std::string ( "Could not open file " ) + filename;
}

// --------------------------------------------------------------------------


/**
 * Writes the vocabulary into an open cv::FileStorage under the key @p name.
 *
 * Layout:
 *   name {
 *     k, L, scoringType, weightingType,
 *     nodes: [ { nodeId, parentId, weight, descriptor }, ... ]
 *     words: [ { wordId, nodeId }, ... ]
 *   }
 * The root node (id 0) is not stored. Nodes are emitted depth-first starting
 * from the root's children; descriptors are serialized with DescManip::toString.
 *
 * @param f    open storage to write to
 * @param name key of the vocabulary map
 */
void Vocabulary::save ( cv::FileStorage &f,
                        const std::string &name ) const
{
    f << name << "{";

    f << "k" << m_k;
    f << "L" << m_L;
    f << "scoringType" << m_scoring;
    f << "weightingType" << m_weighting;

    // tree (depth-first, explicit stack of nodes whose children are pending)
    f << "nodes" << "[";

    std::vector<NodeId> pending;
    pending.push_back ( 0 ); // root

    while ( !pending.empty() )
    {
        const NodeId pid = pending.back();
        pending.pop_back();

        // m_nodes is not modified here, so a reference is safe
        const std::vector<NodeId> &kids = m_nodes[pid].children;
        for ( size_t c = 0; c < kids.size(); ++c )
        {
            const NodeId cid = kids[c];
            const Node &child = m_nodes[cid];

            f << "{:"
              << "nodeId" << ( int ) child.id
              << "parentId" << ( int ) pid
              << "weight" << ( double ) child.weight
              << "descriptor" << DescManip::toString ( child.descriptor )
              << "}";

            if ( !child.isLeaf() )
                pending.push_back ( cid );
        }
    }

    f << "]"; // nodes

    // words
    f << "words" << "[";
    for ( size_t w = 0; w < m_words.size(); ++w )
    {
        f << "{:"
          << "wordId" << ( int ) w
          << "nodeId" << ( int ) m_words[w]->id
          << "}";
    }
    f << "]"; // words

    f << "}";
}


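/*
 * Disabled: an unfinished FileStorage overload of saveToTextFile, kept for
 * reference. It would not compile as written: the local fstream `f` shadows the
 * cv::FileStorage parameter `f`, and `filename` is not declared. (Plain block
 * comment so that Doxygen does not attach it to the next function.)
//void Vocabulary::saveToTextFile(cv::FileStorage &f,
//                                const std::string &name) const
//{
//    fstream f;
//    f.open(filename.c_str(),ios_base::out);
//    f << m_k << " " << m_L << " " << " " << m_scoring << " " << m_weighting << endl;
//
//    for(size_t i=1; i<m_nodes.size();i++)
//    {
//        const Node& node = m_nodes[i];
//
//        f << node.parent << " ";
//        if(node.isLeaf())
//            f << 1 << " ";
//        else
//            f << 0 << " ";
//
//        f << DescManip::toString(node.descriptor) << " " << (double)node.weight << endl;
//    }
//
//    f.close();
//}
 */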
// --------------------------------------------------------------------------


/**
 * Reads a vocabulary written by save(cv::FileStorage&, name).
 *
 * The root (node 0) is implicit; m_nodes is sized to the number of stored
 * nodes plus one, and node and word ids from the file are used as indices.
 *
 * @param fs   open storage to read from
 * @param name key of the vocabulary map
 * @throws std::string if a node or word entry references an id outside the
 *         stored tree (the vocabulary is left empty in that case)
 */
void Vocabulary::load ( const cv::FileStorage &fs,
                        const std::string &name )
{
    m_words.clear();
    m_nodes.clear();

    // Leaves the vocabulary empty and reports a malformed entry, instead of
    // writing out of bounds with an id taken from the file.
    auto fail = [this, &name] ( const char *what )
    {
        m_words.clear();
        m_nodes.clear();
        throw std::string ( what ) + name;
    };

    cv::FileNode fvoc = fs[name];

    m_k = ( int ) fvoc["k"];
    m_L = ( int ) fvoc["L"];
    m_scoring = ( ScoringType ) ( ( int ) fvoc["scoringType"] );
    m_weighting = ( WeightingType ) ( ( int ) fvoc["weightingType"] );

    createScoringObject();

    // nodes
    cv::FileNode fn = fvoc["nodes"];

    m_nodes.resize ( fn.size() + 1 ); // +1 to include root
    m_nodes[0].id = 0;

    for ( unsigned int i = 0; i < fn.size(); ++i )
    {
        NodeId nid = ( int ) fn[i]["nodeId"];
        NodeId pid = ( int ) fn[i]["parentId"];

        // node 0 is the implicit root and must not be redefined
        if ( nid == 0 || nid >= m_nodes.size() || pid >= m_nodes.size() )
            fail ( "Invalid node entry in vocabulary " );

        WordValue weight = ( WordValue ) fn[i]["weight"];
        std::string d = ( std::string ) fn[i]["descriptor"];

        m_nodes[nid].id = nid;
        m_nodes[nid].parent = pid;
        m_nodes[nid].weight = weight;
        m_nodes[pid].children.push_back ( nid );

        DescManip::fromString ( m_nodes[nid].descriptor, d );
    }

    // words
    fn = fvoc["words"];

    m_words.resize ( fn.size() );

    for ( unsigned int i = 0; i < fn.size(); ++i )
    {
        NodeId wid = ( int ) fn[i]["wordId"];
        NodeId nid = ( int ) fn[i]["nodeId"];

        if ( wid >= m_words.size() || nid >= m_nodes.size() )
            fail ( "Invalid word entry in vocabulary " );

        m_nodes[nid].word_id = wid;
        m_words[wid] = &m_nodes[nid];
    }
}

// --------------------------------------------------------------------------

/**
 * Writes a one-line human-readable summary of the vocabulary: branching
 * factor, depth, weighting, scoring and number of words.
 * @param os  stream to write to
 * @param voc vocabulary to describe
 * @return os
 */
std::ostream& operator<< ( std::ostream &os,
                           const Vocabulary &voc )
{
    const char *weighting = "";
    switch ( voc.getWeightingType() )
    {
    case TF_IDF: weighting = "tf-idf"; break;
    case TF:     weighting = "tf";     break;
    case IDF:    weighting = "idf";    break;
    case BINARY: weighting = "binary"; break;
    }

    const char *scoring = "";
    switch ( voc.getScoringType() )
    {
    case L1_NORM:       scoring = "L1-norm";                   break;
    case L2_NORM:       scoring = "L2-norm";                   break;
    case CHI_SQUARE:    scoring = "Chi square distance";       break;
    case KL:            scoring = "KL-divergence";             break;
    case BHATTACHARYYA: scoring = "Bhattacharyya coefficient"; break;
    case DOT_PRODUCT:   scoring = "Dot product";               break;
    }

    os << "Vocabulary: k = " << voc.getBranchingFactor()
       << ", L = " << voc.getDepthLevels()
       << ", Weighting = " << weighting
       << ", Scoring = " << scoring
       << ", Number of words = " << voc.size();

    return os;
}
}
